"""Conflict-detection inference service.

Exposes a single POST /predict endpoint that asks one of three LLM backends
(GPT-4 via OpenRouter, a locally loaded DeepSeek distill, or Fanar) whether
two requirement sentences contradict each other.
"""

from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from openai import OpenAI

print("Version ---- DeepSeek Only")

app = FastAPI()


# -----------------------------
# Request schema
# -----------------------------
class ConflictDetectionRequest(BaseModel):
    """Payload for POST /predict."""

    Req1: str
    Req2: str
    model_choice: str  # "GPT-4", "DeepSeek-Reasoner", "Fanar"
    prompt_type: str   # "zero-shot" or "few-shot"
    # Optional[...] so a JSON null validates; required only for GPT-4.
    api_key: Optional[str] = None


# -----------------------------
# Prompt builder
# -----------------------------
def build_prompt(req1: str, req2: str, prompt_type: str = "zero-shot") -> str:
    """Build the contradiction-detection prompt for two requirements.

    Args:
        req1: First requirement sentence.
        req2: Second requirement sentence.
        prompt_type: "zero-shot" or "few-shot"; any other value falls back
            to a terse yes/no prompt.

    Returns:
        The prompt string to send to the model.
    """
    if prompt_type == "zero-shot":
        return (
            "Do the following sentences contradict each other, Answer with "
            "only 'yes' or 'no', no explanation. \n"
            f" 1.{req1} 2.{req2}"
        )
    elif prompt_type == "few-shot":
        examples = (
            "Example 1:\n"
            "Req1: The system shall allow password reset.\n"
            "Req2: The system shall not allow password reset.\n"
            "Answer: yes\n\n"
            "Example 2:\n"
            "Req1: The system shall support Arabic language.\n"
            "Req2: The system shall support English language.\n"
            "Answer: no\n\n"
        )
        return examples + (
            "Now answer: Do the following sentences contradict each other? \n"
            f" 1.{req1} 2.{req2}"
        )
    else:
        return (
            "Do the following sentences contradict each other, yes or no: "
            f"1.{req1} 2.{req2}"
        )


# -----------------------------
# Startup: load DeepSeek once
# -----------------------------
@app.on_event("startup")
def load_models():
    """Load the DeepSeek tokenizer/model once and cache them on app.state."""
    print("Loading DeepSeek model into memory...")
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The model has no pad token; reuse EOS so batched/padded encoding works.
    app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
    # NOTE: generation settings (max tokens, temperature, do_sample) are
    # arguments to model.generate(), not from_pretrained(); passing them here
    # does not control decoding. Deterministic decoding (do_sample=False) is
    # therefore applied in run_deepseek() instead.
    app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        offload_folder="offload",
    )


# -----------------------------
# Model handlers
# -----------------------------
def run_gpt4(req1, req2, prompt_type, api_key):
    """Query GPT-4 through OpenRouter with the caller-supplied API key."""
    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    prompt = build_prompt(req1, req2, prompt_type)
    completion = client.chat.completions.create(
        model="openai/gpt-4",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=512,
    )
    return completion.choices[0].message.content.strip()


def run_deepseek(req1, req2, prompt_type):
    """Run the locally loaded DeepSeek model and return its decoded output."""
    tokenizer = app.state.deepseek_tokenizer
    model = app.state.deepseek_model
    prompt = build_prompt(req1, req2, prompt_type)
    inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    # device_map="auto" may place the model on GPU/offload devices; the input
    # tensors must be moved to the model's device or generate() will fail.
    device = next(model.parameters()).device
    outputs = model.generate(
        input_ids=inputs.input_ids.to(device),
        attention_mask=inputs.attention_mask.to(device),
        max_new_tokens=256,
        do_sample=False,  # deterministic output (greedy decoding)
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def run_fanar(req1, req2, prompt_type):
    """Query the Fanar API using the FANAR_API environment variable."""
    client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
    prompt = build_prompt(req1, req2, prompt_type)
    response = client.chat.completions.create(
        model="Fanar",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content.strip()


# -----------------------------
# API route
# -----------------------------
@app.post("/predict")
def predict(request: ConflictDetectionRequest):
    """Dispatch the request to the chosen model backend.

    Returns a JSON body with the model's answer, or a 400/500 error payload.
    """
    try:
        if request.model_choice == "GPT-4":
            # GPT-4 is the only backend that requires a per-request key.
            if not request.api_key:
                return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)
            answer = run_gpt4(request.Req1, request.Req2, request.prompt_type, request.api_key)
        elif request.model_choice == "DeepSeek-Reasoner":
            answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
        elif request.model_choice == "Fanar":
            answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
        else:
            return JSONResponse({"error": "Invalid model_choice"}, status_code=400)

        return JSONResponse(
            {"resp": answer, "statusText": "OK", "statusCode": 0},
            status_code=200,
        )
    except Exception as e:
        # Top-level boundary: surface the failure to the client as a 500.
        return JSONResponse({"error": str(e)}, status_code=500)