Spaces:
Sleeping
Sleeping
import os
from typing import Optional

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Build banner printed at import time so deployment logs show which variant is live.
print("Version ---- DeepSeek Only")

# FastAPI application instance; handlers and startup state attach to this.
app = FastAPI()
| # ----------------------------- | |
| # Request schema | |
| # ----------------------------- | |
class ConflictDetectionRequest(BaseModel):
    """Request body for the conflict-detection endpoint.

    Attributes:
        Req1: First requirement sentence.
        Req2: Second requirement sentence.
        model_choice: One of "GPT-4", "DeepSeek-Reasoner", "Fanar".
        prompt_type: "zero-shot" or "few-shot".
        api_key: OpenRouter API key; required only when model_choice == "GPT-4".
    """

    Req1: str
    Req2: str
    model_choice: str  # "GPT-4", "DeepSeek-Reasoner", "Fanar"
    prompt_type: str  # "zero-shot" or "few-shot"
    # Annotated Optional so an explicit null is accepted: plain `str = None`
    # only worked under Pydantic v1's implicit-Optional rule and is rejected
    # by Pydantic v2 and by type checkers.
    api_key: Optional[str] = None  # required only if model_choice == "GPT-4"
| # ----------------------------- | |
| # Prompt builder | |
| # ----------------------------- | |
def build_prompt(req1, req2, prompt_type="zero-shot"):
    """Build the LLM prompt for a pair of requirement sentences.

    prompt_type selects between a bare zero-shot question, a few-shot
    variant prefixed with two worked examples, and a terse yes/no
    fallback for any unrecognized value.
    """
    if prompt_type == "zero-shot":
        return (
            "Do the following sentences contradict each other, Answer with "
            f"only 'yes' or 'no', no explanation. \n 1.{req1} 2.{req2}"
        )
    if prompt_type == "few-shot":
        shots = (
            "Example 1:\n"
            "Req1: The system shall allow password reset.\n"
            "Req2: The system shall not allow password reset.\n"
            "Answer: yes\n\n"
            "Example 2:\n"
            "Req1: The system shall support Arabic language.\n"
            "Req2: The system shall support English language.\n"
            "Answer: no\n\n"
        )
        question = (
            "Now answer: Do the following sentences contradict each other? "
            f"1.{req1} 2.{req2}"
        )
        return shots + question
    # Unknown prompt_type: fall back to the terse yes/no phrasing.
    return f"Do the following sentences contradict each other, yes or no: 1.{req1} 2.{req2}"
| # ----------------------------- | |
| # Startup: load DeepSeek once | |
| # ----------------------------- | |
@app.on_event("startup")
def load_models():
    """Load the DeepSeek tokenizer and model into app.state once at startup.

    NOTE(review): the original defined this function but never registered or
    called it, leaving app.state unpopulated; the startup hook restores the
    behavior the "load DeepSeek once" section comment describes — confirm no
    other module was expected to call it.
    """
    print("Loading DeepSeek model into memory...")
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The tokenizer ships without a pad token; reuse EOS so padded batches work.
    app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
    # temperature/do_sample are generation-time options, not from_pretrained
    # arguments, so they were dropped here; greedy decoding (do_sample=False)
    # is already what model.generate() uses by default in run_deepseek().
    app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",       # shard/offload across available devices
        offload_folder="offload",
    )
| # ----------------------------- | |
| # Model handlers | |
| # ----------------------------- | |
def run_gpt4(req1, req2, prompt_type, api_key):
    """Ask GPT-4 (routed through OpenRouter) whether the two requirements conflict.

    api_key is the caller-supplied OpenRouter key; returns the model's reply
    text with surrounding whitespace stripped.
    """
    prompt = build_prompt(req1, req2, prompt_type)
    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    chat = client.chat.completions.create(
        model="openai/gpt-4",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=512,
    )
    return chat.choices[0].message.content.strip()
def run_deepseek(req1, req2, prompt_type):
    """Run the locally loaded DeepSeek model and return only the generated answer.

    Fixes vs. the original:
    - inputs are moved to the model's device: with device_map="auto" the model
      may not live on CPU, and CPU input tensors would then fail;
    - decoding now skips the prompt tokens, so the caller receives just the
      completion instead of the full prompt echoed back followed by the answer;
    - generation runs under torch.inference_mode() so no autograd state is built.
    """
    tokenizer = app.state.deepseek_tokenizer
    model = app.state.deepseek_model
    prompt = build_prompt(req1, req2, prompt_type)
    inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    inputs = inputs.to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + completion; slice off the prompt tokens
    # before decoding so only the model's answer is returned.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def run_fanar(req1, req2, prompt_type):
    """Ask the hosted Fanar model whether the two requirements conflict.

    The API key is read from the FANAR_API environment variable; returns the
    model's reply text with surrounding whitespace stripped.
    """
    prompt = build_prompt(req1, req2, prompt_type)
    fanar = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
    reply = fanar.chat.completions.create(
        model="Fanar",
        messages=[{"role": "user", "content": prompt}],
    )
    return reply.choices[0].message.content.strip()
| # ----------------------------- | |
| # API route | |
| # ----------------------------- | |
# NOTE(review): the route decorator was missing, so this handler was never
# registered with the app; confirm the intended path with the frontend client.
@app.post("/predict")
def predict(request: ConflictDetectionRequest):
    """Dispatch a conflict-detection request to the selected model backend.

    Returns {"resp": <model answer>, "statusText": "OK", "statusCode": 0} on
    success, 400 for a missing key or unknown model_choice, and 500 when the
    chosen backend raises.
    """
    try:
        if request.model_choice == "GPT-4":
            # GPT-4 goes through OpenRouter, which needs the caller's own key.
            if not request.api_key:
                return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)
            answer = run_gpt4(request.Req1, request.Req2, request.prompt_type, request.api_key)
        elif request.model_choice == "DeepSeek-Reasoner":
            answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
        elif request.model_choice == "Fanar":
            answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
        else:
            return JSONResponse({"error": "Invalid model_choice"}, status_code=400)
        return JSONResponse({"resp": answer, "statusText": "OK", "statusCode": 0}, status_code=200)
    except Exception as e:
        # Boundary handler: surface backend failures as HTTP 500 instead of crashing.
        return JSONResponse({"error": str(e)}, status_code=500)