mayankpuvvala committed on
Commit 8534bdb · verified · 1 Parent(s): 1fc5ab9

Create app.py

Files changed (1)
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import requests
import threading  # added: used below to serve the FastAPI app alongside Gradio
import uvicorn    # added: the FastAPI app needs a server process to be reachable

# ========== FASTAPI BACKEND ==========
app = FastAPI()

model_name = "mayankpuvvala/peft_lora_t5_merged_model_pytorch_issues"

try:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # match your model's base
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("✅ Model loaded successfully.")
except Exception as e:
    print("❌ Model loading error:", e)
    model = None

class PromptInput(BaseModel):
    prompt: str

@app.post("/generate")
async def generate_text(data: PromptInput):
    if model is None:
        return {"error": "Model not loaded properly."}

    prompt = data.prompt.strip()
    if not prompt:
        return {"error": "Empty prompt."}
    if len(prompt.split()) > 150:
        return {"error": "Prompt too long. Limit to ~150 words."}

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.95,
                eos_token_id=tokenizer.eos_token_id
            )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": result}
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        return {"error": "CUDA out of memory. Try shorter input."}
    except Exception as e:
        return {"error": f"Unexpected error: {str(e)}"}

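# Example exchange with the endpoint above (the prompt is illustrative, not
# part of this commit): POST /generate with
#   {"prompt": "DataLoader hangs when num_workers > 0"}
# returns {"generated_text": "..."} on success. Note that failures also come
# back as {"error": "..."} with HTTP 200, so clients must inspect the body
# rather than the status code alone.
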
# ========== GRADIO FRONTEND ==========
def generate_response(prompt):
    # The backend runs in this same process (see the thread below), so it is
    # reachable on localhost rather than via a relative URL.
    try:
        response = requests.post("http://localhost:8000/generate",
                                 json={"prompt": prompt}, timeout=120)
    except requests.RequestException as e:
        return f"Backend request failed: {e}"
    payload = response.json()
    # The backend reports errors with HTTP 200, so check the body.
    if "error" in payload:
        return payload["error"]
    return payload.get("generated_text", "No output returned.")

# The FastAPI app above is never served unless uvicorn runs it; a daemon
# thread keeps port 8000 alive alongside the Gradio server.
threading.Thread(
    target=lambda: uvicorn.run(app, host="127.0.0.1", port=8000, log_level="warning"),
    daemon=True,
).start()

gr.Interface(fn=generate_response, inputs="text", outputs="text").launch()
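
A possible simplification, sketched below under the assumption that the installed Gradio version provides gr.mount_gradio_app (it is not used in this commit): mount the Gradio UI onto the FastAPI app and serve both from a single uvicorn process, removing the localhost HTTP hop entirely. The echo handler is a hypothetical placeholder standing in for the real tokenizer/model call.

from fastapi import FastAPI
from pydantic import BaseModel
import gradio as gr
import uvicorn

api = FastAPI()

class PromptInput(BaseModel):
    prompt: str

@api.post("/generate")
async def generate_text(data: PromptInput):
    # Hypothetical placeholder for the real model.generate(...) call.
    return {"generated_text": f"echo: {data.prompt}"}

def generate_response(prompt: str) -> str:
    # In-process call: the UI no longer needs to POST to localhost.
    return f"echo: {prompt}"

demo = gr.Interface(fn=generate_response, inputs="text", outputs="text")
app = gr.mount_gradio_app(api, demo, path="/")  # UI at "/", API at "/generate"

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)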